/*
 * Copyright (c) 2020 - present, Inspur Genersoft Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.iec.edp.caf.rpc.remote.grpc;

import com.google.protobuf.ByteString;
import io.grpc.*;
import io.grpc.stub.StreamObserver;
import io.iec.edp.caf.rpc.api.channel.RpcChanelAsynCallBack;
import io.iec.edp.caf.rpc.api.common.RpcChannelType;
import io.iec.edp.caf.rpc.api.grpc.GrpcInvokeServiceGrpc;
import io.iec.edp.caf.rpc.api.grpc.GrpcRequest;
import io.iec.edp.caf.rpc.api.grpc.GrpcResponse;
import io.iec.edp.caf.rpc.api.grpc.GrpcVariable;
import io.iec.edp.caf.rpc.api.support.ConstanceVarible;
import io.iec.edp.caf.rpc.api.support.RpcThreadCacheHolder;
import io.iec.edp.caf.rpc.api.support.RpcFiltersContainer;
import io.iec.edp.caf.rpc.api.channel.RpcAbstractChannel;
import io.iec.edp.caf.rpc.api.channel.RpcChannel;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import lombok.var;

import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * gRPC Channel
 */
@Slf4j
public class GrpcChannel extends RpcAbstractChannel {

    /**
     * gRPC Client
     */
    private GrpcInvokeServiceGrpc.GrpcInvokeServiceBlockingStub blockingStub;

    private GrpcInvokeServiceGrpc.GrpcInvokeServiceStub serviceStub;

    /**
     * gRPC Server
     * This server is used to start Netty.
     * Use static to ensure there's only one server.
     */
    private static Server server = null;

    /**
     * Size of header and msg.
     */
    private final static int maxHeaderSize = 64*1024;
    private final static int maxMsgSize = 64*1024*1024;

//    private List<BindableService> grpcServices = new ArrayList<>();

    ManagedChannel managedChannel = null;



    public GrpcChannel(){
        super();
    }
    /**
     * @param host gRPC服务的主机名
     * @param port gRPC服务的端口
     */
    private GrpcChannel(String host, int port) {
        this.host = host;
        this.port = port;
        ManagedChannelBuilder<?> channelBuilder = GrpcChannelBuilder.getChannelBuilder(host,port);
        registerClientInterceptors(channelBuilder);
        managedChannel = channelBuilder.build();
        blockingStub = GrpcInvokeServiceGrpc.newBlockingStub(managedChannel);
        serviceStub = GrpcInvokeServiceGrpc.newStub(managedChannel);
    }

    /**
     * register client filter
     * @param channelBuilder channelBuilder
     */
    private void registerClientInterceptors(ManagedChannelBuilder<?> channelBuilder){
        channelBuilder.intercept(new GrpcClientInterceptor());
    }

    /**
     * register server filter
     * @param serverBuilder serverBuilder
     */
    private void registerServerInterceptors(ServerBuilder<?> serverBuilder){
        serverBuilder.intercept(new GrpcServerInterceptor());
    }


    /**
     * Channel type
     * different Channel decision the protocol
     * @return channelType
     */
    @Override
    public String channelType() {
        return RpcChannelType.GRPC.getValue();
    }

    /**
     * build a CAFChannel object
     * @param host 域名/ip
     * @param port 端口号
     * @return RpcChannel
     */
    @Override
    public RpcChannel buildChannel(String host, int port) {
        return buildChannel(host,port,null);
    }

    @Override
    public RpcChannel buildChannel(String host, int port, String filterType) {
        this.filterType = filterType;
        return new GrpcChannel(host,port);
    }

    @Override
    public RpcChannel buildChannel(URL url, String filterType) {
        String host = url.getHost();
        int port = url.getPort();
        return buildChannel(host,port,filterType);
    }

    @Override
    public String serviceDiscover(String serviceUnit) {
        return super.serviceDiscover(serviceUnit + "-grpc");
    }

    @Override
    public void addHeaders(Map<String, String> context) {
        if(context.keySet().size()>0){
//            Metadata metadata = new Metadata();
            context.forEach((key,val)->{
                RpcThreadCacheHolder.setValue(key,val);
//                metadata.put(Metadata.Key.of(key,Metadata.ASCII_STRING_MARSHALLER),val);
            });
//            blockingStub = MetadataUtils.attachHeaders(blockingStub,metadata);
        }
    }


    /**
     * start gRPC Netty Server
     */
    @SneakyThrows
    @Override
    public void startServer(Map<String,Object> args) {
        Integer serverPort = (Integer) args.get(GrpcVariable.SERVER_PORT);
        //todo maybe is services
        BindableService grpcService = (BindableService) args.get(GrpcVariable.SERVER_SERVICE);
        if(serverPort==null||grpcService==null)
            throw new RuntimeException("Can't recognize arguments of grpc server.");

        ServerBuilder<?> serverBuilder = ServerBuilder.forPort(serverPort).maxInboundMetadataSize(maxHeaderSize).maxInboundMessageSize(maxMsgSize);
        // 将具体实现的服务添加到gRPC服务中
        serverBuilder.addService(grpcService);

        registerServerInterceptors(serverBuilder);

        server = serverBuilder.build();
        server.start();

        //log.error("GrpcService has already start on port: "+serverPort);
    }

    /**
     * stop gRPC Netty Server
     */
    @Override
    protected void stopServer() {
        synchronized(server){
            if(server!=null&&!server.isShutdown())
                server.shutdown();
        }
    }

    @Override
    public void startClient() {

    }

    /**
     * stop gRPC Client
     */
    @Override
    public void stopClient(){
        this.managedChannel.shutdown();
    }

    @Override
    public String invokeRemoteService(String serviceId, String version, LinkedHashMap<String, String> parameters, HashMap<String, String> context) throws IllegalAccessException, ClassNotFoundException, InstantiationException, IOException, NoSuchMethodException, InvocationTargetException {
        try{
            GrpcRequest req = GrpcRequest.newBuilder()
                    .setServiceId(serviceId)
                    .setVersion(version)
                    .putAllParams(parameters)
                    .build();
            GrpcResponse resp = blockingStub.grpcRemoteInvoke(req);

            return resp.getMessage();
        } catch (Exception ex){
            throw ex;
        }

    }

    @Override
    public InputStream invokeRemoteServiceStream(String serviceId, String version, LinkedHashMap<String, String> parameters, HashMap<String, String> context) throws IllegalAccessException, ClassNotFoundException, InstantiationException, IOException, NoSuchMethodException, InvocationTargetException {
        try{
            GrpcRequest req = GrpcRequest.newBuilder()
                    .setServiceId(serviceId)
                    .setVersion(version)
                    .putAllParams(parameters)
                    .build();
            var resp = blockingStub.grpcRemoteInvokeFromStream(req);

            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            while (resp.hasNext()) {
                var content = resp.next().getContent();
                if(content!=null){
                    var cont = content.toByteArray();
                    outputStream.write(cont);
                }
            }

            ByteArrayInputStream swapStream = new ByteArrayInputStream(outputStream.toByteArray());
            return swapStream;
        }catch (Exception ex){
            throw ex;
        }

    }

    @Override
    public void invokeRemoteServiceStreamAsyn(String serviceId, String version, LinkedHashMap<String, String> parameters, HashMap<String, String> context, RpcChanelAsynCallBack callback) throws IllegalAccessException, ClassNotFoundException, InstantiationException, IOException, NoSuchMethodException, InvocationTargetException {
        try{
            GrpcRequest req = GrpcRequest.newBuilder()
                    .setServiceId(serviceId)
                    .setVersion(version)
                    .putAllParams(parameters)
                    .build();
            serviceStub.grpcRemoteInvokeFromStream(req,new StreamFromObserver(callback));
        }catch (Exception ex){
            throw ex;
        }

    }

    @Override
    public String invokeRemoteServiceStream(InputStream inputStream, String serviceId, String version, LinkedHashMap<String, String> parameters, HashMap<String, String> context) throws IllegalAccessException, ClassNotFoundException, InstantiationException, IOException, NoSuchMethodException, InvocationTargetException {
        try{
            StreamObserver<GrpcRequest> streamObserver = this.serviceStub.grpcRemoteInvokeToStream(new StreamToObserver());

            // upload file as chunk
            byte[] bytes = new byte[4096];
            int size;
            while ((size = inputStream.read(bytes)) > 0){
                //build request parameter
                GrpcRequest req = GrpcRequest.newBuilder()
                        .setServiceId(serviceId)
                        .setVersion(version)
                        .putAllParams(parameters)
                        .setContent(ByteString.copyFrom(bytes, 0 , size))
                        .build();

                streamObserver.onNext(req);
            }

            // close the stream
            inputStream.close();
            streamObserver.onCompleted();
            return "";
        }catch (Exception ex){
            throw ex;
        }

    }

    class StreamToObserver implements StreamObserver<GrpcResponse> {

        @Override
        public void onNext(GrpcResponse fileUploadResponse) {
            log.debug(fileUploadResponse.getMessage());
        }

        @SneakyThrows
        @Override
        public void onError(Throwable throwable) {
            log.error(throwable.getMessage(),throwable);
            throw throwable;
        }

        @Override
        public void onCompleted() {

        }

    }

    class StreamFromObserver implements StreamObserver<GrpcResponse> {

        private RpcChanelAsynCallBack callBack;
        public StreamFromObserver(RpcChanelAsynCallBack callBack){
            this.callBack = callBack;
        }
        @Override
        public void onNext(GrpcResponse value) {
            this.callBack.onCallBack(value.getContent());
        }

        @Override
        public void onError(Throwable cause) {
            log.error("Error occurred, cause {}", cause.getMessage());
        }

        @Override
        public void onCompleted() {
            log.info("invokeRemoteServiceStreamAsyn completed");
        }
    }


    /**
     * gRPC Client interceptor
     */
    class GrpcClientInterceptor implements ClientInterceptor {

        private int times = 0;

        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
            return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {

                //可以拦截不同阶段
                @Override
                public void start(Listener<RespT> responseListener, Metadata headers) {
                    //here to add client requset headers
                    Map<String,Object> reqContext = new HashMap<>();
                    var channelFilterType = (filterType!=null && !"".equals(filterType))?filterType:ConstanceVarible.RPC_FILTER_MODE_DEFAULT;
                    //invoke global iGix default filters
                    var filters = RpcFiltersContainer.getClientFilter().get(channelFilterType);
                    if(filters!=null){
                        filters.forEach(x->{
                            x.doOutFilter(reqContext);
                        });
                    }


                    reqContext.keySet().forEach(x->{
                        if(!x.equalsIgnoreCase(ConstanceVarible.RPC_HEADER)){
                            headers.put(Metadata.Key.of(x, Metadata.ASCII_STRING_MARSHALLER),reqContext.get(x).toString());
                        }
                    });

                    super.start(new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {
                        @Override
                        public void onHeaders(Metadata headers) {
                            /**
                             * if you don't need receive header from server, you can
                             * use {@link io.grpc.stub.MetadataUtils#attachHeaders}
                             * directly to send header
                             */
                            //here to get server response headers
                            //here to get client requset headers 从header中还原信息
                            Map<String, Object> repContext = new HashMap<>();
                            headers.keys().forEach(x->{
                                repContext.put(x,headers.get(Metadata.Key.of(x, Metadata.ASCII_STRING_MARSHALLER)));
                            });

                            //invoke global iGix default filters
                            var channelFilterType = (filterType!=null && !"".equals(filterType))?filterType:ConstanceVarible.RPC_FILTER_MODE_DEFAULT;
                            var filters = RpcFiltersContainer.getClientFilter().get(channelFilterType);
                            if(filters!=null){
                                filters.forEach(x->{
                                    x.doInFilter(repContext);
                                });
                            }

                            super.onHeaders(headers);
                        }
                    }, headers);
                }

                @Override
                public void request(int numMessages) {
                    super.request(numMessages);
                }

                @Override
                public Attributes getAttributes() {
                    return super.getAttributes();
                }

                @Override
                public void sendMessage(ReqT message) {
                    super.sendMessage(message);
                }

            };
        }
    }

    /**
     * gRPC server interceptor
     */
    class GrpcServerInterceptor implements ServerInterceptor {

        @Override
        public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
            try {
                //here to get client requset headers 从header中还原信息
                Map<String, Object> reqContext = new HashMap<>();
                headers.keys().forEach(x->{
                    reqContext.put(x,headers.get(Metadata.Key.of(x, Metadata.ASCII_STRING_MARSHALLER)));
                });

                //invoke global iGix default filters
                var channelFilterType = (filterType!=null && !"".equals(filterType))?filterType:ConstanceVarible.RPC_FILTER_MODE_DEFAULT;
                var filters = RpcFiltersContainer.getServerFilter().get(channelFilterType);
                if(filters!=null){
                    filters.forEach(x->{
                        x.doInFilter(reqContext);
                    });
                }

            }catch (Throwable e){
                call.close(Status.CANCELLED.withDescription(e.getMessage()),headers);
                throw e;
            }

            //服务端写回参数(服务端未结束)
            ServerCall<ReqT, RespT> serverCall = new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
                @Override
                public void sendHeaders(Metadata headers) {
                    //here to add server response headers
                    //invoke global iGix default filters
                    Map<String, Object> repContext = new HashMap<>();
                    var channelFilterType = (filterType!=null && !"".equals(filterType))?filterType:ConstanceVarible.RPC_FILTER_MODE_DEFAULT;
                    var filters = RpcFiltersContainer.getServerFilter().get(channelFilterType);
                    if(filters!=null){
                        filters.forEach(x->{
                            x.doOutFilter(repContext);
                        });
                    }

                    repContext.keySet().forEach(x->{
                        if(!x.equalsIgnoreCase(ConstanceVarible.RPC_HEADER)){
                            headers.put(Metadata.Key.of(x, Metadata.ASCII_STRING_MARSHALLER),repContext.get(x).toString());
                        }
                    });
                    super.sendHeaders(headers);
                }
            };
            //执行服务端业务逻辑
            return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(serverCall,headers)) {

                @Override
                public void onMessage(ReqT message) {
                    super.onMessage(message);
                }

                /**
                 * 请求结束 处理服务端返回拦截器清理上下文信息
                 */
                @Override
                public void onComplete() {
                    super.onComplete();
                }
            };
        }
    }

}
